I love R. I wouldn’t be where I am in my career without the tools of the Tidyverse and the kindness of the Rstats and Rladies communities. R gave me a platform for entering data science and exposed me to the world of software engineering and development along the way. Choosing to use R over Stata for my Economics Honours thesis was probably the best choice I made (hides behind shield).
I picked up new programming languages over the past two years working as a data scientist. The choice of programming languages at work is determined by the design of the tech stack and the analytics teams. We mostly use PySpark at work for exploratory data analysis (EDA) and “feature engineering”. So far, I’ve managed. I picked up standard Python libraries like pandas and NumPy, the machine learning library (MLlib) in PySpark, and wrangled many tables in SQL.
But the one thing I couldn’t put myself through was learning how to plot in Python. I couldn’t get over the ease and beauty of building visualisations using ggplot2. I would do all my work in SQL and PySpark, but I would always finish off my EDA using ggplot.
I’ve finally been pushing myself to get acquainted with Python plotting libraries. I figured that it wouldn’t hurt to be bilingual in the two main data science languages. Plus, the developers at work are all Python users (pythonistas?), so it would be mutually beneficial to communicate using the programming lingua franca.
I found it difficult to navigate the vast array of packages Python has for plotting: Matplotlib, seaborn, Altair, Bokeh, etc. But I realised that several of the higher level options, such as seaborn and the plotting built into pandas, sit on top of Matplotlib. So while you may plot using a higher level package like seaborn, a lot of the tweaking is done using Matplotlib functions.
Learning Matplotlib is difficult. The first thing that took me a while to wrap my head around was how the data needed to be structured to build the plots I wanted. Matplotlib works most naturally with data in a wide format, with one column per line to draw, as opposed to a long (tidy) format. Seaborn was much easier to grasp as it expects long/tidy data structures when plotting. This feels more intuitive to me as a ggplot user.
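To make the wide versus long distinction concrete, here is a minimal sketch on made-up numbers. The column names mirror the ones used later in this post, but the values and the two countries "A" and "B" are invented purely for illustration.

```python
import pandas as pd

# Toy long (tidy) data: one row per country per day. All values are made up.
long_df = pd.DataFrame({
    "days_elapsed": [0, 1, 2, 0, 1, 2],
    "cname": ["A", "A", "A", "B", "B", "B"],
    "cu_cases": [100, 150, 230, 110, 180, 300],
})

# Long -> wide: one column per country, indexed by day.
wide_df = long_df.pivot(index="days_elapsed", columns="cname", values="cu_cases")

# Wide -> long again with melt.
back_to_long = (
    wide_df.reset_index()
    .melt(id_vars="days_elapsed", var_name="cname", value_name="cu_cases")
)

print(wide_df)
print(back_to_long)
```

The wide frame is the shape that is convenient for line plots in Matplotlib and pandas, while the long frame is the shape that ggplot and seaborn expect.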
I came across Kieran Healy’s blog post a couple of weeks ago where he plotted cumulative reported cases of the COVID-19 virus for the top 50 countries. I thought it would be a good exercise for me to try to replicate this in Matplotlib.
I’ve always admired this style of visualisation. It preserves the variation of all the countries while highlighting the key country you want to focus on for comparison and trend analysis. I use similar plots at work all the time to assist in presentations - a clean and compact way of presenting rich datasets.
I saved myself the effort of recreating the dataset from scratch by using Kieran’s R code found in his COVID Github Repository. The COVID data comes from the European Centre for Disease Prevention and Control (ECDC). Note that I have not checked the validity or accuracy of the data source or scraping process. This is just an exercise in creating a similar figure in Python.
Most of the data preprocessing has been completed in the R script. However, there is still a bit of work required to get the necessary datasets and objects prepared for plotting. The exact data and scripts I used here
First, load in the data, libraries and clean up some of the country labels.
import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as tick
# Setting some Matplotlib parameters
%matplotlib inline
mpl.rcParams['figure.dpi']= 300
plt.style.use('seaborn-talk')  # named 'seaborn-v0_8-talk' in Matplotlib 3.6 and later
# Replace some of the country names with shorter labels
data = pd.read_csv('covid/bw/data/cov_case_curve2.csv')\
.replace({'cname' :
{'United States' : 'USA',
'Iran, Islamic Republic of' : 'Iran',
'Korea, Republic of' : 'South Korea',
'United Kingdom' : 'UK'}})\
.drop('Unnamed: 0', axis=1)
# Show first 5 rows of the main dataset
data.head()
Next, we subset our data so it contains the 50 countries with the most cumulative cases as of April 6th (df_50) and pivot it into a ‘wide’ format. The index of the dataframe will be the days_elapsed variable (x-axis) and the columns will contain cumulative cases for each country (y-axis, presented on log10 scale).
For ggplot and seaborn, the country names would instead stay in a single variable that determines the facets of the plot. This makes creating the grid of a facet plot much simpler (one line of code).
## Countries - top 50 cumulative cases as of 2020-03-27
## We will use this to create a dataset with only the top 50 countries
## that we will highlight in all the subplots
top_50 = data.loc[data.groupby(["cname"])["cu_cases"].idxmax()]\
.sort_values('cu_cases', ascending = False)\
.head(50)\
.loc[:, ['iso3', 'cname', 'cu_cases']]
## Filter countries in top 50
_df = data.loc[data['iso3'].isin(top_50['iso3'])]
## Restructure data into wide format
## Top 50
df_50 = _df.pivot(index = 'days_elapsed',
                  values = 'cu_cases',
                  columns = 'cname')
# pivot already uses days elapsed as the index (x axis),
# so no reset_index or column drop is needed
# Display one of the dataframes
df_50.head()
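The subsetting and pivoting steps above can be tried on a tiny stand-in dataset. This is only a sketch: the three countries and their counts are invented, and it keeps the top 2 rather than the top 50 so the result is easy to check by eye.

```python
import pandas as pd

# Made-up data using the same column names as the post's dataset.
data = pd.DataFrame({
    "iso3": ["AAA"] * 3 + ["BBB"] * 3 + ["CCC"] * 2,
    "cname": ["A"] * 3 + ["B"] * 3 + ["C"] * 2,
    "days_elapsed": [0, 1, 2, 0, 1, 2, 0, 1],
    "cu_cases": [100, 150, 230, 110, 180, 300, 100, 120],
})

# Row with the largest cumulative count per country, then keep the top 2.
top_n = (
    data.loc[data.groupby("cname")["cu_cases"].idxmax()]
    .sort_values("cu_cases", ascending=False)
    .head(2)
)

# Keep only those countries and pivot to wide.
# days_elapsed becomes the index directly.
wide = (
    data.loc[data["iso3"].isin(top_n["iso3"])]
    .pivot(index="days_elapsed", columns="cname", values="cu_cases")
)

print(top_n[["cname", "cu_cases"]])
print(wide)
```

Wrapping the chain in parentheses instead of ending each line with a backslash is a style choice; it avoids the continuation character silently swallowing the next line.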
I started by plotting a single country before adding the other components of the final chart. Starting simple helped me get the basics right first. The chart below shows the cumulative growth of cases in China by days since the 100th case.
You can plot this using matplotlib or the plot function of the pandas dataframe. All the customisation uses matplotlib functions. The difference between using matplotlib and pandas.plot is minor in a simple plot. However, it is much easier to start off with pandas.plot then add customisation via matplotlib when it comes to more complicated visualisations (thanks to Chris Moffitt’s blog post on effective plotting in Python for this advice).
## Plot China
fig, ax = plt.subplots(figsize = (8, 13))
# Matplotlib
plt.plot(df_50.index, df_50['China'], color = 'grey')
# Pandas
# df_50['China'].plot(color='grey')
# Plot customisation
plt.title('Cumulative Growth rate of China')
plt.xlabel('Days since 100th confirmed case')
plt.ylabel('Cumulative number of cases (log10 scale)')
ax.set_xticks(np.arange(0, 80, 20), minor=False)
ax.set_yscale('log', base=10)  # 'basey' was replaced by 'base' in Matplotlib 3.3
ax.yaxis.set_major_formatter(tick.ScalarFormatter())
ax.get_yaxis().set_major_formatter(plt.FuncFormatter(lambda x, loc: "{:,}".format(int(x))))
plt.show()
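As a compact illustration of the "start with pandas, customise with Matplotlib" approach, here is a sketch on invented numbers. The series values are made up, the Agg backend is chosen only so the snippet runs outside a notebook, and the `base` keyword assumes Matplotlib 3.3 or later.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs outside a notebook
import matplotlib.pyplot as plt
import matplotlib.ticker as tick
import pandas as pd

# Made-up cumulative counts, indexed by days elapsed.
cases = pd.Series([100, 1_000, 10_000, 100_000], index=[0, 20, 40, 60], name="Example")

fig, ax = plt.subplots(figsize=(4, 6))
cases.plot(ax=ax, color="grey")  # pandas draws the line on our Axes
ax.set_yscale("log", base=10)    # assumes Matplotlib 3.3 or later
ax.yaxis.set_major_formatter(
    tick.FuncFormatter(lambda y, pos: "{:,}".format(int(y)))
)
ax.set_xlabel("Days since 100th confirmed case")
ax.set_ylabel("Cumulative number of cases (log10 scale)")

# Ask the formatter how it would label the 10,000 tick.
label_for_10k = ax.yaxis.get_major_formatter()(10_000, 0)
print(label_for_10k)
```

Only the last formatter set on an axis is used, so a single `FuncFormatter` is enough to get comma-separated tick labels.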
Next, I display the values of all countries in a single figure. Matplotlib requires the code to explicitly identify which columns of the dataframe it is plotting. To do this I created a list of all countries which I loop over to plot each individual line. This will also be used to ‘highlight’ the relevant countries in each subplot later on.
The plot function in pandas is more forgiving. It will plot all columns for you automatically. No for loop needed.
# Create list of countries to loop over
countries = top_50['cname'].drop_duplicates().tolist()
# Print first 5 countries
countries[0:5]
## ['USA', 'Spain', 'Italy', 'Germany', 'China']
fig, axes= plt.subplots(figsize = (5, 8))
# Pandas
df_50.plot(color='grey', alpha = 0.6, linewidth=0.5, legend = False, ax = axes)
# Matplotlib customisation
plt.title('Cumulative reported cases for all countries')
plt.xlabel('Days since 100th confirmed case')
plt.ylabel('Cumulative number of cases (log10 scale)')
axes.set_xticks(np.arange(0, 80, 20), minor=False)
axes.set_yscale('log', base=10)  # 'basey' was replaced by 'base' in Matplotlib 3.3
axes.grid(alpha=0.2)
axes.yaxis.set_major_formatter(tick.ScalarFormatter())
axes.get_yaxis().set_major_formatter(plt.FuncFormatter(lambda x, loc: "{:,}".format(int(x))))
plt.show()
fig, axes = plt.subplots(figsize = (5, 8))
# Matplotlib
for count in countries:
    plt.plot(df_50.index, df_50[count], color = 'grey', alpha = 0.6, linewidth=0.5)
# Matplotlib Customisation
plt.title('Cumulative reported cases for all countries')
plt.xlabel('Days since 100th confirmed case')
plt.ylabel('Cumulative number of cases (log10 scale)')
axes.set_xticks(np.arange(0, 80, 20), minor=False)
axes.set_yscale('log', base=10)  # 'basey' was replaced by 'base' in Matplotlib 3.3
axes.yaxis.set_major_formatter(tick.ScalarFormatter())
axes.get_yaxis().set_major_formatter(plt.FuncFormatter(lambda x, loc: "{:,}".format(int(x))))
plt.show()
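The two approaches above can be compared side by side on a small invented dataframe. This is a sketch rather than the code used for the figures: the three columns and their values are made up, and the Agg backend is only there so it runs outside a notebook.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs outside a notebook
import matplotlib.pyplot as plt
import pandas as pd

# Made-up wide data: one column per country.
wide = pd.DataFrame(
    {"A": [100, 150, 230], "B": [110, 180, 300], "C": [100, 120, 160]},
    index=pd.Index([0, 1, 2], name="days_elapsed"),
)

fig, (ax_pandas, ax_loop) = plt.subplots(1, 2, figsize=(8, 4), sharey=True)

# pandas: one call draws every column as its own line.
wide.plot(ax=ax_pandas, color="grey", alpha=0.6, linewidth=0.5, legend=False)

# Matplotlib: loop over the columns explicitly.
for name in wide.columns:
    ax_loop.plot(wide.index, wide[name], color="grey", alpha=0.6, linewidth=0.5)

n_pandas, n_loop = len(ax_pandas.lines), len(ax_loop.lines)
print(n_pandas, n_loop)
```

Both panels end up with one line per column. As far as I can tell, Matplotlib's `plot` also accepts a 2D array for y and draws one line per column, so `ax_loop.plot(wide.index, wide.to_numpy())` may be a loop-free alternative worth trying.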
This is where things get complicated. We need to loop over the countries list for both matplotlib and pandas to create subplots for each of the top 50 countries. This means that matplotlib has another for loop to compute all the cumulative cases as well as doing it across subplots. I spent quite a few hours trying not to do this (I would’ve got this post out a lot earlier).
But it turns out the best way to do this is to plot the 50 countries using pandas.plot and use the for loop to plot across all subplots. I’ve kept my attempted matplotlib version below the pandas one. Please get in touch if you know a better way of doing this.
fig, axes = plt.subplots(10, 5, figsize = (16, 30), sharex = True, sharey = True)
for idx, count in enumerate(countries):
    ax = axes[idx//5][idx%5]
    # Get grey lines for all subplots
    df_50.plot(ax = ax,
               legend = False,
               color='grey',
               alpha = 0.6,
               linewidth=0.5)
    ax.set_title(str(count), size = 9)
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.set_yscale('log', base=10)  # 'basey' was replaced by 'base' in Matplotlib 3.3
    ax.yaxis.set_major_formatter(tick.ScalarFormatter())
    ax.get_yaxis().set_major_formatter(plt.FuncFormatter(lambda x, loc: "{:,}".format(int(x))))
    ax.grid(alpha=0.1)
    ax.set_xticks(np.arange(0, 80, 20), minor=False)
fig.suptitle('Cumulative Number of Reported Cases of COVID-19: Top 50 Countries', fontsize=20,
             x=0.12, y=.91, horizontalalignment='left', verticalalignment='top')
fig.text(0.12, 0.895, 'Date of Saturday, April 4, 2020', fontsize=16, ha='left', va='top')
fig.text(0.04, 0.5, 'Cumulative number of cases (log10 scale)', va='center', rotation='vertical', size = 16)
fig.text(0.5, 0.097, 'Days since 100th confirmed case', ha='center', size = 16)
plt.show()
fig, axes = plt.subplots(10, 5, figsize = (16, 30), sharex = True, sharey = True)
for idx, count in enumerate(countries):
    ax = axes[idx//5][idx%5]
    # Draw every country's line in grey on this subplot
    for country in countries:
        ax.plot(df_50.index, df_50[country],
                color='grey',
                alpha = 0.6,
                linewidth=0.5)
    ax.set_title(str(count), size = 9)
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.set_yscale('log', base=10)  # 'basey' was replaced by 'base' in Matplotlib 3.3
    ax.yaxis.set_major_formatter(tick.ScalarFormatter())
    ax.get_yaxis().set_major_formatter(plt.FuncFormatter(lambda x, loc: "{:,}".format(int(x))))
    ax.grid(alpha=0.2)
    ax.set_xticks(np.arange(0, 80, 20), minor=False)
fig.suptitle('Cumulative Number of Reported Cases of COVID-19: Top 50 Countries', fontsize=20,
             x=0.12, y=.91, horizontalalignment='left', verticalalignment='top')
fig.text(0.12, 0.895, 'Date of Saturday, April 4, 2020', fontsize=16, ha='left', va='top')
fig.text(0.04, 0.5, 'Cumulative number of cases (log10 scale)', va='center', rotation='vertical', size = 16)
fig.text(0.5, 0.097, 'Days since 100th confirmed case', ha='center', size = 16)
plt.show()
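One possible way to trim the subplot loop, offered as a sketch rather than a definitive answer, is to iterate over `axes.flat` instead of computing row and column positions by hand. The 2 by 2 grid and the four made-up countries below are assumptions chosen to keep the example small.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs outside a notebook
import matplotlib.pyplot as plt
import pandas as pd

# Made-up wide data: one column per country.
wide = pd.DataFrame(
    {"A": [100, 150, 230], "B": [110, 180, 300],
     "C": [100, 120, 160], "D": [100, 105, 111]},
    index=pd.Index([0, 1, 2], name="days_elapsed"),
)

fig, axes = plt.subplots(2, 2, figsize=(6, 6), sharex=True, sharey=True)

# axes.flat walks the grid row by row, so no idx // ncols arithmetic is needed.
for ax, name in zip(axes.flat, wide.columns):
    # Passing the whole 2D array draws one grey line per column in a single call.
    ax.plot(wide.index, wide.to_numpy(), color="grey", alpha=0.6, linewidth=0.5)
    ax.set_title(name, size=9)
    ax.set_yscale("log", base=10)  # assumes Matplotlib 3.3 or later

titles = [ax.get_title() for ax in axes.flat]
print(titles)
```

Pairing each Axes with its country through `zip` also means the loop stops cleanly if there are fewer countries than panels.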
Finally, all that’s left to do is to highlight the line corresponding to the country in the subplot and add a point for the end of each line. To plot the points at the end of each red line, I create a dataframe which contains the top 50 countries with the corresponding latest day elapsed and cumulative case count, and query it inside the loop. This is completed using just pandas.plot and customised with matplotlib functions.
# Subset dataframe with top 50 countries and the latest cumulative case value
markers = data.loc[data.groupby(["cname"])["cu_cases"].idxmax()]\
.sort_values('cu_cases', ascending = False)\
.head(50)\
.loc[:, ['days_elapsed', 'cname', 'cu_cases']]\
.reset_index(drop = True)
markers.head()
fig, axes = plt.subplots(10, 5, figsize = (16, 30), sharex = True, sharey = True)
for idx, count in enumerate(countries):
    ax = axes[idx//5][idx%5]
    # Get grey lines for all subplots
    df_50.plot(ax = ax,
               legend = False,
               color='grey',
               alpha = 0.6,
               linewidth=0.5)
    # Highlight relevant countries for each subplot
    df_50[str(count)].plot(ax = ax,
                           legend = False,
                           color='red',
                           linewidth=0.9)
    # Add markers at the end of each line
    markers.query('cname == "{}"'.format(count))\
           .plot.scatter(ax = ax,
                         x='days_elapsed',
                         y='cu_cases',
                         color = 'red')
    ax.set_title(str(count), size = 9)
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.set_yscale('log', base=10)  # 'basey' was replaced by 'base' in Matplotlib 3.3
    ax.yaxis.set_major_formatter(tick.ScalarFormatter())
    ax.get_yaxis().set_major_formatter(plt.FuncFormatter(lambda x, loc: "{:,}".format(int(x))))
    ax.grid(alpha=0.1)
    ax.set_xticks(np.arange(0, 80, 20), minor=False)
fig.suptitle('Cumulative Number of Reported Cases of COVID-19: Top 50 Countries', fontsize=20,
             x=0.12, y=.91, horizontalalignment='left', verticalalignment='top')
fig.text(0.12, 0.895, 'Date of Saturday, April 4, 2020', fontsize=16, ha='left', va='top')
fig.text(0.04, 0.5, 'Cumulative number of cases (log10 scale)', va='center', rotation='vertical', size = 16)
fig.text(0.5, 0.097, 'Days since 100th confirmed case', ha='center', size = 16)
plt.show()
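The highlight-and-marker idea can be seen in miniature below. This sketch makes two assumptions that differ from the code above: the data is invented, and the end point is taken as the last non-missing value of each column rather than looked up in a separate markers dataframe.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs outside a notebook
import matplotlib.pyplot as plt
import pandas as pd

# Made-up wide data; B has one fewer day of observations than A.
wide = pd.DataFrame(
    {"A": [100, 150, 230], "B": [110, 180, None]},
    index=pd.Index([0, 1, 2], name="days_elapsed"),
)

fig, axes = plt.subplots(1, 2, figsize=(6, 3), sharex=True, sharey=True)
end_points = {}

for ax, name in zip(axes.flat, wide.columns):
    # Grey background lines for every country.
    wide.plot(ax=ax, color="grey", alpha=0.6, linewidth=0.5, legend=False)
    # Red line for the country this panel is about.
    series = wide[name].dropna()
    series.plot(ax=ax, color="red", linewidth=0.9)
    # Marker on the last observed point of the highlighted series.
    last_x, last_y = series.index[-1], series.iloc[-1]
    ax.scatter([last_x], [last_y], color="red", zorder=3)
    ax.set_title(name, size=9)
    ax.set_yscale("log", base=10)  # assumes Matplotlib 3.3 or later
    end_points[name] = (last_x, last_y)

print(end_points)
```

Because the counts are cumulative, the last observed value and the maximum coincide, which is why either lookup should place the marker at the same spot.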
Resources I found useful as an R user learning Python:
Effectively using Matplotlib - great for outlining a set of principles and steps for plotting
Pandas comparison with R - Method chaining preserves the main functionality of the tidyverse %>% pipe, just need to look up the corresponding functions
Pandas plotting - useful summary of pandas plotting tools